Pull request overview
Adds AWQ (AWQ→Marlin) quantized model support to MES by introducing a quantization config/method abstraction, wiring quantization into model layers, and triggering post-load weight repacking for Marlin kernels.
Changes:
- Add --quantization CLI flag and plumb quantization selection through Engine -> MESConfig -> model layers.
- Introduce quantization framework + AWQ→Marlin implementation (config, utilities, repack/apply path).
- Update linear layers and Qwen2/Qwen3 model components to accept and use optional quant_config, and run post-load quant weight processing.
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| ultils/loader.py | Extends loader to optionally post-process quantized weights after loading. |
| schemas/config.py | Adds quantization selection + parsing of quantization_config into AWQMarlinConfig. |
| openai_server/fast_api.py | Adds CLI argument --quantization/-q and passes it into Engine. |
| core/engine.py | Accepts quantization arg and constructs MESConfig with it. |
| core/gpu_worker.py | Passes mes_config.quantization_config into load_model. |
| models/qwen2.py | Threads quant_config into attention + MLP linear layers. |
| models/qwen3.py | Threads quant_config into attention linear layers. |
| layers/linear.py | Adds optional quantization to Linear modules; uses quant method apply path when enabled. |
| layers/quantization/base_config.py | Defines base interfaces for quantization configs and linear quant methods. |
| layers/quantization/awq_marlin.py | Implements AWQ→Marlin repack + Marlin GEMM apply for quantized linear. |
| layers/quantization/marlin_utils.py | Adds packing/permutation/shape-check/workspace helpers for Marlin. |
| layers/quantization/quant_utils.py | Adds scalar type definitions used by quantization methods. |
| layers/quantization/__init__.py | Exposes quantization public API symbols. |

    def load_model(model: nn.Module, path: str, quant_config=None):
        """Load model weights; supports quantized models."""
        packed_modules_mapping = getattr(model, "packed_modules_mapping", {})

        # 1. Load all weights

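The current quantized path makes load_model crash when it handles packed_modules_mapping. That branch calls getattr(param, "weight_loader") with no default, but a quantized Linear creates qweight/qzeros/scales parameters and never attaches a weight_loader attribute to them. Any weight name that needs packed_modules_mapping remapping, such as q_proj.qweight or gate_proj.qweight, therefore raises AttributeError. Suggestion: set an appropriate weight_loader (with shard_id handling) on qweight/qzeros/scales when the quantized parameters are created, or make the packed branch of the loader tolerate a missing weight_loader (it would still need to shard/concatenate correctly).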
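
The failure mode and the second option in the comment above can be sketched without PyTorch. This is only an illustration: the `SimpleNamespace` objects stand in for `nn.Parameter`, and `resolve_weight_loader` is a hypothetical helper, not a function from this PR. The real packed branch also passes a shard_id, which this fallback does not handle.

```python
from types import SimpleNamespace


def default_weight_loader(param, loaded_weight):
    """Fallback: store the checkpoint value unchanged (no sharding)."""
    param.data = loaded_weight


def resolve_weight_loader(param):
    """Return the parameter's own weight_loader, else the default loader."""
    return getattr(param, "weight_loader", default_weight_loader)


# Hypothetical stand-ins for nn.Parameter objects.
weight = SimpleNamespace(data=None)
weight.weight_loader = lambda p, w: setattr(p, "data", ("custom", w))
qweight = SimpleNamespace(data=None)  # quantized param: no weight_loader attached

# getattr without a default raises, which is the crash described above.
try:
    getattr(qweight, "weight_loader")
    crashed = False
except AttributeError:
    crashed = True

# With a default, both parameters load; qweight simply gets a plain copy.
resolve_weight_loader(weight)(weight, "w")
resolve_weight_loader(qweight)(qweight, "qw")
print(crashed, weight.data, qweight.data)
```

A fallback like this only avoids the AttributeError. As the comment notes, packed and tensor-parallel weights still need a loader that slices and places each shard correctly.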

        # Create parameters according to the quantization config
        if self.quant_method is not None:
            self.quant_method.create_weights(
                layer=self,
                input_size_per_partition=input_size,
                output_partition_sizes=[output_size],
                input_size=input_size,
                output_size=output_size,
                params_dtype=torch.float16,
            )
        else:
            self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size))
            self.weight.weight_loader = self.weight_loader

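In the quantized branch, create_weights() registers only qweight/qzeros/scales and does not bind a weight_loader to these Parameters the way the unquantized weight does. As a result, load_model falls back to default_weight_loader and does a plain copy_ (with TP>1 this causes a shape mismatch because there is no sharding logic), and the packed_modules_mapping branch raises outright because weight_loader is missing. Suggestion: implement and bind a dedicated weight_loader for each of qweight/qzeros/scales (sharding correctly along the output or input dimension for Column/Row parallelism), and make sure the shard_id path for packed QKV / gate_up also loads correctly.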
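
For the packed shard_id path, the main bookkeeping is where each shard lands along the output dimension. The sketch below computes those offsets under two assumptions that are not confirmed by the PR: the fused layer knows its per-shard output sizes, and pack_factor is 32 // weight_bits (8 for 4-bit weights). `packed_shard_slice` and the QKV sizes are hypothetical.

```python
def packed_shard_slice(output_partition_sizes, shard_index, pack_factor=1):
    """Return (offset, size) of one shard along the output (N) dimension.

    pack_factor=1 suits scales; pack_factor=8 suits 4-bit qweight/qzeros,
    where eight values are assumed to share one int32 column.
    """
    for size in output_partition_sizes:
        if size % pack_factor != 0:
            raise ValueError(f"shard size {size} is not divisible by {pack_factor}")
    packed = [size // pack_factor for size in output_partition_sizes]
    return sum(packed[:shard_index]), packed[shard_index]


# Hypothetical fused QKV projection: q=4096, k=1024, v=1024 output columns.
qkv_sizes = [4096, 1024, 1024]
shard_ids = {"q": 0, "k": 1, "v": 2}

# qweight/qzeros use packed columns; scales use unpacked columns.
qweight_slices = {
    name: packed_shard_slice(qkv_sizes, index, pack_factor=8)
    for name, index in shard_ids.items()
}
scales_slices = {
    name: packed_shard_slice(qkv_sizes, index)
    for name, index in shard_ids.items()
}
print(qweight_slices)
print(scales_slices)
```

A dedicated weight_loader could use each (offset, size) pair to select the destination columns before copying the checkpoint shard into the fused parameter.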

    def create_weights(
        self,
        layer: nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the quantized weight parameters."""
        output_size_per_partition = sum(output_partition_sizes)

        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        # Check that the shapes are compatible with Marlin
        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size,
        )

        num_groups = input_size_per_partition // group_size

        # qweight: INT4 weights packed into INT32
        qweight = nn.Parameter(
            torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )

        # qzeros: per-group zero-points
        qzeros = nn.Parameter(
            torch.empty(
                num_groups,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )

        # scales: per-group scaling factors
        scales = nn.Parameter(
            torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.num_groups = num_groups

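AWQMarlinLinearMethod.create_weights() registers qweight/qzeros/scales but provides no sharded-loading logic. For ColumnParallelLinear, the "output dimension" of qweight/scales/qzeros is usually the second dimension of these tensors (index 1; N or N/pack_factor), not layer.tp_dim=0. For RowParallelLinear, splitting the input dimension also requires a narrow along the K dimension of qweight. Without a dedicated weight_loader, loading fails or produces wrong weights when tensor_parallel_size>1. Suggestion: define and attach a weight_loader for qweight/qzeros/scales here according to the layer type / parallel mode (and support shard_id for packed modules).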
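
The sharding described above can be written down as index arithmetic. The sketch below is a hypothetical helper, not code from this PR. It assumes the checkpoint layout implied by create_weights (qweight [K, N // pack_factor], qzeros [K // group_size, N // pack_factor], scales [K // group_size, N]) and returns (dim, start, length) triples in the argument order of torch.Tensor.narrow.

```python
def tp_shard_ranges(k, n, group_size, pack_factor, tp_size, tp_rank, parallel):
    """Return {tensor_name: (dim, start, length)} for one tensor-parallel rank."""
    if parallel == "column":
        # Column parallel: split the output dimension N, which is dim 1 here.
        if n % (tp_size * pack_factor) != 0:
            raise ValueError("N must be divisible by tp_size * pack_factor")
        n_per_rank = n // tp_size
        packed_per_rank = n_per_rank // pack_factor
        return {
            "qweight": (1, tp_rank * packed_per_rank, packed_per_rank),
            "qzeros": (1, tp_rank * packed_per_rank, packed_per_rank),
            "scales": (1, tp_rank * n_per_rank, n_per_rank),
        }
    if parallel == "row":
        # Row parallel: split the input dimension K, which is dim 0 here.
        if k % (tp_size * group_size) != 0:
            raise ValueError("K must be divisible by tp_size * group_size")
        k_per_rank = k // tp_size
        groups_per_rank = k_per_rank // group_size
        return {
            "qweight": (0, tp_rank * k_per_rank, k_per_rank),
            "qzeros": (0, tp_rank * groups_per_rank, groups_per_rank),
            "scales": (0, tp_rank * groups_per_rank, groups_per_rank),
        }
    raise ValueError(f"unknown parallel mode: {parallel}")


# Hypothetical 4096x4096 layer, 4-bit AWQ (pack_factor=8), group_size=128, TP=2.
column = tp_shard_ranges(4096, 4096, 128, 8, tp_size=2, tp_rank=1, parallel="column")
row = tp_shard_ranges(4096, 4096, 128, 8, tp_size=2, tp_rank=1, parallel="row")
print(column)
print(row)
```

A weight_loader bound to each parameter could apply its triple as loaded_weight.narrow(dim, start, length) before copying into the local shard.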

        # Load the quantization config from model_config
        self.quantization_config = self._load_quantization_config(model_config, quantization)

    def _load_quantization_config(self, model_config, user_method: Optional[str]):
        """Load the quantization config from the model config."""
        if not user_method:
            return None

        # Get the quantization config from the config object
        quant_config = getattr(model_config, 'quantization_config', None)

        if not quant_config:
            raise ValueError(
                f"Quantization method '{user_method}' was specified, "
                f"but the model config has no quantization_config field. "
                f"Please confirm that the model is quantized."
            )

        # quantization_config may be a dict or an object
        if hasattr(quant_config, 'to_dict'):
            quant_config = quant_config.to_dict()

        # Get the quantization method from the model config
        file_method = quant_config.get("quant_method", "").lower()
        user_method_lower = user_method.lower()

        # Validate compatibility
        if user_method_lower == "awq_marlin" and file_method == "awq":
            print(f"[MESConfig] Detected an AWQ model; using the Marlin kernel for acceleration")
        elif user_method_lower != file_method:
            raise ValueError(
                f"Quantization method mismatch: the command line specifies '{user_method_lower}', "
                f"but the model config is '{file_method}'"
            )

        # Create the quantization config object
        if user_method_lower in ["awq", "awq_marlin"]:

            bits = quant_config.get("bits", 4)
            group_size = quant_config.get("group_size", 128)
            zero_point = quant_config.get("zero_point", True)

            print(f"[MESConfig] Quantization method: awq_marlin")
            print(f"[MESConfig] Quantization parameters: bits={bits}, group_size={group_size}, zero_point={zero_point}")

            return AWQMarlinConfig(
                weight_bits=bits,
                group_size=group_size,
                zero_point=zero_point,
            )
        else:
            raise ValueError(f"Unsupported quantization method: {user_method_lower}; only awq/awq_marlin are currently supported")

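The new quantization-config parsing and AWQ→Marlin post-processing logic currently has no test coverage (for example: when the CLI specifies --quantization=awq_marlin, MESConfig correctly parses quantization_config from the transformers config, and each layer's quant_method.process_weights_after_loading is triggered after load_model). The repository already has model/engine accuracy tests; consider adding a minimal unit or integration test that covers this new path (building a minimal fake model_config.quantization_config dict avoids any dependency on real quantized weights).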
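
One possible shape for such a test is sketched below. Because MESConfig and AWQMarlinConfig cannot be imported here, `parse_quant_config` and `FakeAWQMarlinConfig` are hypothetical stand-ins that mirror the parsing logic shown in the diff; a real test would call MESConfig directly with the same fake model_config.

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class FakeAWQMarlinConfig:
    """Stand-in for AWQMarlinConfig with the three fields the parser sets."""
    weight_bits: int
    group_size: int
    zero_point: bool


def parse_quant_config(model_config, user_method):
    """Stand-in mirroring the _load_quantization_config logic in the diff."""
    if not user_method:
        return None
    quant_config = getattr(model_config, "quantization_config", None)
    if not quant_config:
        raise ValueError("model config has no quantization_config")
    if hasattr(quant_config, "to_dict"):
        quant_config = quant_config.to_dict()
    file_method = quant_config.get("quant_method", "").lower()
    method = user_method.lower()
    # awq_marlin on an awq checkpoint is allowed; otherwise names must match.
    if not (method == "awq_marlin" and file_method == "awq") and method != file_method:
        raise ValueError(f"method mismatch: '{method}' vs '{file_method}'")
    if method not in ("awq", "awq_marlin"):
        raise ValueError(f"unsupported quantization method: {method}")
    return FakeAWQMarlinConfig(
        weight_bits=quant_config.get("bits", 4),
        group_size=quant_config.get("group_size", 128),
        zero_point=quant_config.get("zero_point", True),
    )


def raises_value_error(fn):
    """Return True if calling fn raises ValueError."""
    try:
        fn()
    except ValueError:
        return True
    return False


# Minimal fake model config: no real quantized weights are needed.
fake = SimpleNamespace(
    quantization_config={"quant_method": "awq", "bits": 4, "group_size": 64}
)

parsed = parse_quant_config(fake, "awq_marlin")
disabled = parse_quant_config(fake, None)
missing_rejected = raises_value_error(lambda: parse_quant_config(SimpleNamespace(), "awq"))
mismatch_rejected = raises_value_error(lambda: parse_quant_config(fake, "gptq"))
print(parsed, disabled, missing_rejected, mismatch_rejected)
```

The post-load hook could be covered in the same spirit by attaching a fake quant_method to a tiny module and checking that process_weights_after_loading was called once load_model returns.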

        marlin_permute_scales,
        awq_to_marlin_zero_points,
    )
    from layers.quantization.quant_utils import ScalarType, get_scalar_types

Import of 'ScalarType' is not used.

    - from layers.quantization.quant_utils import ScalarType, get_scalar_types
    + from layers.quantization.quant_utils import get_scalar_types